/*
 * author	: calvin
 * email	: calvinwilliams@163.com
 *
 * Licensed under the LGPL v2.1, see the file LICENSE in base directory.
 */

#include "cdbc.h"

#include "byteorder.h"

#include <libpq-fe.h>
#include <postgres.h>
#include <catalog/pg_type.h>

/* Compile-time switch for the _TRACE diagnostics below: non-zero enables them, 0 compiles them out. */
#define _DEBUG		0

#if _DEBUG
/* printf-style trace to stdout, prefixed with source file, line and function name; flushed after every message. */
#define _TRACE(_format_,...) { printf( "TRACE - %s:%d:%s - "_format_"\n" , __FILE__,__LINE__,__FUNCTION__ , ##__VA_ARGS__ ); fflush(stdout); }
#else
/* Tracing disabled: every _TRACE(...) invocation expands to nothing, so its arguments are never evaluated. */
#define _TRACE(_format_,...)
#endif

/* Connection handle of this PostgreSQL driver; allocated by ConnectToDatabase and released by DisconnectFromDatabase. */
struct DatabaseConnection
{
	/* libpq connection object obtained from PQsetdbLogin and closed with PQfinish */
	PGconn		*pgsql_conn ;
} ;

/*
 * Exported driver entry points, declared through the func* function types
 * (presumably provided by cdbc.h) so each definition below is checked
 * against the expected signature.
 *
 * NOTE(review): AutoCommitTransaction is declared here but no definition is
 * visible in this file, while BeginTransaction is defined below without a
 * matching declaration in this list - confirm which name the loader expects.
 */
DLLEXPORT funcConnectToDatabase ConnectToDatabase ;
DLLEXPORT funcDisconnectFromDatabase DisconnectFromDatabase ;
DLLEXPORT funcExecuteSql ExecuteSql ;
DLLEXPORT funcAutoCommitTransaction AutoCommitTransaction ;
DLLEXPORT funcCommitTransaction CommitTransaction ;
DLLEXPORT funcRollbackTransaction RollbackTransaction ;

/*
 * Open a PostgreSQL connection and return it wrapped in a heap-allocated handle.
 *
 * db_host , db_port : server address; db_host must be non-NULL and db_port positive,
 *                     otherwise CDBC_ERROR_PARAMETER is recorded.
 * db_user , db_pass , db_name : forwarded unchanged to PQsetdbLogin (not validated here).
 *
 * Returns the new handle, or NULL after recording CDBC_ERROR_PARAMETER ,
 * CDBC_ERROR_ALLOC or CDBC_ERROR_CONNECT through DBCSetLastErrno.
 * A returned handle is released with DisconnectFromDatabase.
 */
struct DatabaseConnection *ConnectToDatabase( char *db_host , int db_port , char *db_user , char *db_pass , char *db_name )
{
	struct DatabaseConnection	*handle = NULL ;
	char				port_text[ 40 + 1 ] = "" ;
	
	if( db_host == NULL || db_port <= 0 )
	{
		DBCSetLastErrno( CDBC_ERROR_PARAMETER );
		return NULL;
	}
	
	/* zero-filled handle, so pgsql_conn starts out NULL */
	handle = (struct DatabaseConnection *)calloc( 1 , sizeof(struct DatabaseConnection) ) ;
	if( handle == NULL )
	{
		DBCSetLastErrno( CDBC_ERROR_ALLOC );
		return NULL;
	}
	
	/* libpq takes the port as a decimal string */
	snprintf( port_text , sizeof(port_text)-1 , "%d" , db_port );
	handle->pgsql_conn = PQsetdbLogin( db_host , port_text , NULL , NULL , db_name , db_user , db_pass ) ;
	
	/* a missing connection object and a connection that is not CONNECTION_OK are handled the same way */
	if( handle->pgsql_conn == NULL || PQstatus( handle->pgsql_conn ) != CONNECTION_OK )
	{
		DisconnectFromDatabase( & handle );
		DBCSetLastErrno( CDBC_ERROR_CONNECT );
		return NULL;
	}
	
	return handle;
}

/*
 * Close the libpq connection held by *db_conn , free the handle and reset
 * the caller's pointer to NULL.
 *
 * A NULL db_conn records CDBC_ERROR_PARAMETER ; a NULL *db_conn is a no-op.
 */
void DisconnectFromDatabase( struct DatabaseConnection **db_conn )
{
	struct DatabaseConnection	*handle = NULL ;
	
	if( db_conn == NULL )
	{
		DBCSetLastErrno( CDBC_ERROR_PARAMETER );
		return;
	}
	
	handle = (*db_conn) ;
	if( handle == NULL )
		return;
	
	if( handle->pgsql_conn != NULL )
	{
		PQfinish( handle->pgsql_conn );
		handle->pgsql_conn = NULL ;
	}
	
	free( handle );
	(*db_conn) = NULL ;
}

/*
 * Translate a PostgreSQL column type OID into the matching CDBC field type.
 * Any OID without an explicit case below is reported as CDBC_FIELDTYPE_OTHER.
 */
static enum FieldType ConvertPgsqlFieldType( Oid type )
{
	switch( type )
	{
		case INT2OID :
			return CDBC_FIELDTYPE_INT16;
		case INT4OID :
			return CDBC_FIELDTYPE_INT32;
		case INT8OID :
			return CDBC_FIELDTYPE_INT64;
		case FLOAT4OID :
			return CDBC_FIELDTYPE_FLOAT;
		case FLOAT8OID :
			return CDBC_FIELDTYPE_DOUBLE;
		case NUMERICOID :
			return CDBC_FIELDTYPE_DECIMAL;
		case CHAROID :
			return CDBC_FIELDTYPE_CHAR;
		case VARCHAROID :
			return CDBC_FIELDTYPE_VARCHAR;
		case DATEOID :
			return CDBC_FIELDTYPE_DATE;
		case TIMEOID :
			return CDBC_FIELDTYPE_TIME;
		case TIMESTAMPOID :
			return CDBC_FIELDTYPE_TIMESTAMP;
		default :
			return CDBC_FIELDTYPE_OTHER;
	}
}

void ExecuteSql( struct DatabaseConnection *db_conn , char *sql , struct FieldBind *binds_array , int binds_array_length , int *row_count , int *col_count , struct FieldInfo **query_field_set , char ***query_result_set , int *affected_count )
{
	Oid			*pgsql_param_type = NULL ;
	char 			**pgsql_param_value = NULL ;
	int			*pgsql_param_length = NULL ;
	int			*pgsql_param_format = NULL ;
	PGresult		*pgsql_res = NULL ;
	int			binds_array_no ;
	struct FieldBind	*binds_array_offsetptr = NULL ;
	ExecStatusType		exec_status_type ;
	int			pgsql_col_count , pgsql_col_index ;
	int			pgsql_row_count , pgsql_row_index ;
	int			pgsql_affected_count ;
	size_t			alloc_size ;
	size_t			set_index ;
	struct FieldInfo	*pgsql_query_field_set = NULL ;
	char			**pgsql_query_result_set = NULL ;
	char			*fname = NULL ;
	char			*fvalue = NULL ;
	int			nret = 0 ;
	
	DBCFreeSqlResult( query_field_set , query_result_set );
	
	if( binds_array && binds_array_length > 0 )
	{
		pgsql_param_type = (Oid*)malloc( sizeof(Oid) * binds_array_length ) ;
		if( pgsql_param_type == NULL )
		{
			_TRACE( "alloc failed , errno[%d]" , errno )
			DBCSetLastErrno( CDBC_ERROR_BIND );
			DBCSetLastNativeErrno( 0 );
			DBCSetLastNativeError( "" );
			return;
		}
		memset( pgsql_param_type , 0x00 , sizeof(Oid) * binds_array_length );
		
		pgsql_param_value = (char**)malloc( sizeof(char*) * binds_array_length ) ;
		if( pgsql_param_value == NULL )
		{
			_TRACE( "alloc failed , errno[%d]" , errno )
			free( pgsql_param_type );
			DBCSetLastErrno( CDBC_ERROR_BIND );
			DBCSetLastNativeErrno( 0 );
			DBCSetLastNativeError( "" );
			return;
		}
		memset( pgsql_param_value , 0x00 , sizeof(char*) * binds_array_length );
		
		pgsql_param_length = (int*)malloc( sizeof(int) * binds_array_length ) ;
		if( pgsql_param_length == NULL )
		{
			_TRACE( "alloc failed , errno[%d]" , errno )
			free( pgsql_param_type );
			free( pgsql_param_value );
			DBCSetLastErrno( CDBC_ERROR_BIND );
			DBCSetLastNativeErrno( 0 );
			DBCSetLastNativeError( "" );
			return;
		}
		memset( pgsql_param_length , 0x00 , sizeof(int) * binds_array_length );
		
		pgsql_param_format = (int*)malloc( sizeof(int) * binds_array_length ) ;
		if( pgsql_param_format == NULL )
		{
			_TRACE( "alloc failed , errno[%d]" , errno )
			free( pgsql_param_type );
			free( pgsql_param_value );
			free( pgsql_param_length );
			DBCSetLastErrno( CDBC_ERROR_BIND );
			DBCSetLastNativeErrno( 0 );
			DBCSetLastNativeError( "" );
			return;
		}
		memset( pgsql_param_format , 0x00 , sizeof(int) * binds_array_length );
		
		for( binds_array_no = 0 , binds_array_offsetptr = binds_array ; binds_array_no < binds_array_length ; binds_array_no++ , binds_array_offsetptr++ )
		{
			if( binds_array_offsetptr->buffer_type == CDBC_FIELDTYPE_INT16 )
			{
				uint16_t	u16 ;
				u16 = (uint16_t)*(int16_t*)(binds_array_offsetptr->buffer) ;
				u16 = HTON16( u16 ) ;
				binds_array_offsetptr->buffer_alloced = (char*)malloc( sizeof(uint16_t) ) ;
				if( binds_array_offsetptr->buffer_alloced == NULL )
				{
					_TRACE( "alloc failed , errno[%d]" , errno )
					goto _GOTO_ERROR_RETURN;
				}
				*(uint16_t*)(binds_array_offsetptr->buffer_alloced) = u16 ;
				pgsql_param_type[binds_array_no] = INT2OID ;
				pgsql_param_value[binds_array_no] = binds_array_offsetptr->buffer_alloced ;
				pgsql_param_length[binds_array_no] = sizeof(int16_t) ;
				pgsql_param_format[binds_array_no] = 1 ;
				_TRACE( "bind sql param - buffer_type[%d]->[%d] buffer[%d][%"PRIi16"]" , binds_array_offsetptr->buffer_type,pgsql_param_type[binds_array_no] , pgsql_param_length[binds_array_no] , *(int16_t*)(binds_array_offsetptr->buffer) )
			}
			else if( binds_array_offsetptr->buffer_type == CDBC_FIELDTYPE_INT32 )
			{
				uint32_t	u32 ;
				u32 = (uint32_t)*(int32_t*)(binds_array_offsetptr->buffer) ;
				u32 = HTON32( u32 ) ;
				binds_array_offsetptr->buffer_alloced = (char*)malloc( sizeof(uint32_t) ) ;
				if( binds_array_offsetptr->buffer_alloced == NULL )
				{
					_TRACE( "alloc failed , errno[%d]" , errno )
					goto _GOTO_ERROR_RETURN;
				}
				*(uint32_t*)(binds_array_offsetptr->buffer_alloced) = u32 ;
				pgsql_param_type[binds_array_no] = INT4OID ;
				pgsql_param_value[binds_array_no] = binds_array_offsetptr->buffer_alloced ;
				pgsql_param_length[binds_array_no] = sizeof(int32_t) ;
				pgsql_param_format[binds_array_no] = 1 ;
				_TRACE( "bind sql param - buffer_type[%d]->[%d] buffer[%d][%"PRIi32"]" , binds_array_offsetptr->buffer_type,pgsql_param_type[binds_array_no] , pgsql_param_length[binds_array_no],*(int32_t*)(binds_array_offsetptr->buffer) )
			}
			else if( binds_array_offsetptr->buffer_type == CDBC_FIELDTYPE_INT64 )
			{
				uint64_t	u64 ;
				u64 = (uint64_t)*(int64_t*)(binds_array_offsetptr->buffer) ;
				u64 = HTON64( u64 ) ;
				binds_array_offsetptr->buffer_alloced = (char*)malloc( sizeof(uint64_t) ) ;
				if( binds_array_offsetptr->buffer_alloced == NULL )
				{
					_TRACE( "alloc failed , errno[%d]" , errno )
					goto _GOTO_ERROR_RETURN;
				}
				*(uint64_t*)(binds_array_offsetptr->buffer_alloced) = u64 ;
				pgsql_param_type[binds_array_no] = INT8OID ;
				pgsql_param_value[binds_array_no] = binds_array_offsetptr->buffer_alloced ;
				pgsql_param_length[binds_array_no] = sizeof(int64_t) ;
				pgsql_param_format[binds_array_no] = 1 ;
				_TRACE( "bind sql param - buffer_type[%d]->[%d] buffer[%d][%"PRIi64"]" , binds_array_offsetptr->buffer_type,pgsql_param_type[binds_array_no] , pgsql_param_length[binds_array_no] , *(int64_t*)(binds_array_offsetptr->buffer) )
			}
			else if( binds_array_offsetptr->buffer_type == CDBC_FIELDTYPE_FLOAT )
			{
				nret = asprintf( & (binds_array_offsetptr->buffer_alloced) , "%f" , *(float*)(binds_array_offsetptr->buffer) ) ;
				if( nret == -1 )
				{
					_TRACE( "asprintf failed , errno[%d]" , errno )
					goto _GOTO_ERROR_RETURN;
				}
				pgsql_param_type[binds_array_no] = FLOAT4OID ;
				pgsql_param_value[binds_array_no] = binds_array_offsetptr->buffer_alloced ;
				pgsql_param_length[binds_array_no] = 0 ;
				pgsql_param_format[binds_array_no] = 0 ;
				_TRACE( "bind sql param - buffer_type[%d]->[%d] buffer[%s]" , binds_array_offsetptr->buffer_type,pgsql_param_type[binds_array_no] , pgsql_param_value[binds_array_no] )
			}
			else if( binds_array_offsetptr->buffer_type == CDBC_FIELDTYPE_DOUBLE )
			{
				nret = asprintf( & (binds_array_offsetptr->buffer_alloced) , "%lf" , *(double*)(binds_array_offsetptr->buffer) ) ;
				if( nret == -1 )
				{
					_TRACE( "asprintf failed , errno[%d]" , errno )
					goto _GOTO_ERROR_RETURN;
				}
				pgsql_param_type[binds_array_no] = FLOAT8OID ;
				pgsql_param_value[binds_array_no] = binds_array_offsetptr->buffer_alloced ;
				pgsql_param_length[binds_array_no] = 0 ;
				pgsql_param_format[binds_array_no] = 0 ;
				_TRACE( "bind sql param - buffer_type[%d]->[%d] buffer[%s]" , binds_array_offsetptr->buffer_type,pgsql_param_type[binds_array_no] , pgsql_param_value[binds_array_no] )
			}
			else if( binds_array_offsetptr->buffer_type == CDBC_FIELDTYPE_DECIMAL )
			{
				nret = asprintf( & (binds_array_offsetptr->buffer_alloced) , "%lf" , *(double*)(binds_array_offsetptr->buffer) ) ;
				if( nret == -1 )
				{
					_TRACE( "asprintf failed , errno[%d]" , errno )
					goto _GOTO_ERROR_RETURN;
				}
				pgsql_param_type[binds_array_no] = NUMERICOID ;
				pgsql_param_value[binds_array_no] = binds_array_offsetptr->buffer_alloced ;
				pgsql_param_length[binds_array_no] = 0 ;
				pgsql_param_format[binds_array_no] = 0 ;
				_TRACE( "bind sql param - buffer_type[%d]->[%d] buffer[%s]" , binds_array_offsetptr->buffer_type,pgsql_param_type[binds_array_no] , pgsql_param_value[binds_array_no] )
			}
			else if( binds_array_offsetptr->buffer_type == CDBC_FIELDTYPE_CHAR )
			{
				pgsql_param_type[binds_array_no] = CHAROID ;
				pgsql_param_value[binds_array_no] = binds_array_offsetptr->buffer ;
				pgsql_param_length[binds_array_no] = binds_array_offsetptr->buffer_length ;
				pgsql_param_format[binds_array_no] = 0 ;
				_TRACE( "bind sql param - buffer_type[%d]->[%d] buffer[%d][%.*s]" , binds_array_offsetptr->buffer_type,pgsql_param_type[binds_array_no] , pgsql_param_length[binds_array_no] , pgsql_param_length[binds_array_no],pgsql_param_value[binds_array_no] )
			}
			else if( binds_array_offsetptr->buffer_type == CDBC_FIELDTYPE_VARCHAR )
			{
				pgsql_param_type[binds_array_no] = VARCHAROID ;
				pgsql_param_value[binds_array_no] = binds_array_offsetptr->buffer ;
				pgsql_param_length[binds_array_no] = binds_array_offsetptr->buffer_length ;
				pgsql_param_format[binds_array_no] = 0 ;
				_TRACE( "bind sql param - buffer_type[%d]->[%d] buffer[%d][%.*s]" , binds_array_offsetptr->buffer_type,pgsql_param_type[binds_array_no] , pgsql_param_length[binds_array_no] , pgsql_param_length[binds_array_no],pgsql_param_value[binds_array_no] )
			}
			else if( binds_array_offsetptr->buffer_type == CDBC_FIELDTYPE_DATETIME )
			{
				struct tm	*p_tm = (struct tm *)(binds_array_offsetptr->buffer) ;
				nret = asprintf( & (binds_array_offsetptr->buffer_alloced) , "%04d-%02d-%02d %02d:%02d:%02d" , p_tm->tm_year+1900 , p_tm->tm_mon+1 , p_tm->tm_mday , p_tm->tm_hour , p_tm->tm_min , p_tm->tm_sec ) ;
				if( nret == -1 )
				{
					_TRACE( "asprintf failed , errno[%d]" , errno )
					goto _GOTO_ERROR_RETURN;
				}
				pgsql_param_type[binds_array_no] = TIMESTAMPOID ;
				pgsql_param_value[binds_array_no] = binds_array_offsetptr->buffer_alloced ;
				pgsql_param_length[binds_array_no] = 0 ;
				pgsql_param_format[binds_array_no] = 0 ;
				_TRACE( "bind sql param - buffer_type[%d]->[%d] buffer[%s]" , binds_array_offsetptr->buffer_type,pgsql_param_type[binds_array_no] , pgsql_param_value[binds_array_no] )
			}
			else
			{
				_TRACE( "unknow cdbc_type[%d]" , binds_array_offsetptr->buffer_type )
_GOTO_ERROR_RETURN :
				free( pgsql_param_type );
				free( pgsql_param_value );
				free( pgsql_param_length );
				free( pgsql_param_format );
				DBCSetLastErrno( CDBC_ERROR_BIND );
				DBCSetLastNativeErrno( 0 );
				DBCSetLastNativeError( "" );
				return;
			}
		}
		
		pgsql_res = PQexecParams( db_conn->pgsql_conn , sql , binds_array_length , pgsql_param_type , (const char * const*)pgsql_param_value , pgsql_param_length , pgsql_param_format , 0 ) ;
		_TRACE( "PQexecParams return[%p]" , pgsql_res )
	}
	else
	{
		pgsql_res = PQexec( db_conn->pgsql_conn , sql ) ;
		_TRACE( "PQexec return[%p]" , pgsql_res )
	}
	if( binds_array && binds_array_length > 0 )
	{
		free( pgsql_param_type );
		free( pgsql_param_value );
		free( pgsql_param_length );
		free( pgsql_param_format );
	}
	if( pgsql_res == NULL )
	{
		DBCSetLastErrno( CDBC_ERROR_QUERY );
		DBCSetLastNativeErrno( 0 );
		DBCSetLastNativeError( "" );
		return;
	}
	
	exec_status_type = PQresultStatus( pgsql_res ) ;
	_TRACE( "PQresultStatus return[%d]" , (int)exec_status_type )
	if( exec_status_type == PGRES_COMMAND_OK )
	{
		pgsql_row_count = 0 ;
		pgsql_col_count = 0 ;
		pgsql_query_field_set = NULL ;
		pgsql_query_result_set = NULL ;
		pgsql_affected_count = atoi( PQcmdTuples( pgsql_res ) ) ;
		
		PQclear( pgsql_res );
	}
	else if( exec_status_type == PGRES_TUPLES_OK )
	{
		pgsql_row_count = PQntuples( pgsql_res ) ;
		_TRACE( "pgsql_row_count[%d]" , pgsql_row_count )
		pgsql_col_count = PQnfields( pgsql_res ) ;
		_TRACE( "pgsql_col_count[%d]" , pgsql_col_count )
		pgsql_affected_count = 0 ;
		
		alloc_size = sizeof(struct FieldInfo) * (pgsql_col_count+1) ;
		pgsql_query_field_set = (struct FieldInfo *)malloc( alloc_size ) ;
		if( pgsql_query_field_set == NULL )
		{
			DBCSetLastErrno( CDBC_ERROR_ALLOC );
			DBCSetLastNativeErrno( 0 );
			DBCSetLastNativeError( "" );
			PQclear( pgsql_res );
			return;
		}
		memset( pgsql_query_field_set , 0x00 , alloc_size );
		pgsql_query_field_set[pgsql_col_count].field_length = pgsql_row_count ;
		
		for( pgsql_col_index = 0 ; pgsql_col_index < pgsql_col_count ; pgsql_col_index++ )
		{
			fname = PQfname(pgsql_res,pgsql_col_index) ;
			pgsql_query_field_set[pgsql_col_index].field_name = strdup( fname ) ;
			if( pgsql_query_field_set[pgsql_col_index].field_name == NULL )
			{
				DBCFreeSqlResult( & pgsql_query_field_set , & pgsql_query_result_set );
				DBCSetLastErrno( CDBC_ERROR_ALLOC );
				DBCSetLastNativeErrno( 0 );
				DBCSetLastNativeError( "" );
				PQclear( pgsql_res );
				return;
			}
			pgsql_query_field_set[pgsql_col_index].field_type = ConvertPgsqlFieldType( PQftype(pgsql_res,pgsql_col_index) ) ;
			if( pgsql_query_field_set[pgsql_col_index].field_type == CDBC_FIELDTYPE_INVALID )
			{
				DBCFreeSqlResult( & pgsql_query_field_set , & pgsql_query_result_set );
				DBCSetLastErrno( CDBC_ERROR_FIELD_TYPE_NOT_SUPPORT );
				DBCSetLastNativeErrno( 0 );
				DBCSetLastNativeError( "" );
				PQclear( pgsql_res );
				return;
			}
			pgsql_query_field_set[pgsql_col_index].field_length = PQfsize(pgsql_res,pgsql_col_index) ;
			pgsql_query_field_set[pgsql_col_index].field_decimal_length = PQfmod(pgsql_res,pgsql_col_index) ;
		}
		
		alloc_size = sizeof(char*) * pgsql_row_count * pgsql_col_count ;
		pgsql_query_result_set = (char**)malloc( alloc_size ) ;
		if( pgsql_query_result_set == NULL )
		{
			DBCFreeSqlResult( & pgsql_query_field_set , & pgsql_query_result_set );
			DBCSetLastErrno( CDBC_ERROR_ALLOC );
			DBCSetLastNativeErrno( 0 );
			DBCSetLastNativeError( "" );
			PQclear( pgsql_res );
			PQclear( pgsql_res );
			return;
		}
		memset( pgsql_query_result_set , 0x00 , alloc_size );
		
		set_index = 0 ;
		for( pgsql_row_index = 0 ; pgsql_row_index < pgsql_row_count ; pgsql_row_index++ )
		{
			for( pgsql_col_index = 0 ; pgsql_col_index < pgsql_col_count ; pgsql_col_index++ , set_index++ )
			{
				fvalue = PQgetvalue(pgsql_res,pgsql_row_index,pgsql_col_index) ;
				if( ! PQgetisnull(pgsql_res,pgsql_row_index,pgsql_col_index) )
					pgsql_query_result_set[set_index] = strndup( fvalue , PQgetlength(pgsql_res,pgsql_row_index,pgsql_col_index) ) ;
				else
					pgsql_query_result_set[set_index] = NULL ;
			}
		}
		
		PQclear( pgsql_res );
	}
	else
	{
		DBCSetLastErrno( CDBC_ERROR_QUERY );
		DBCSetLastNativeErrno( (int)PQresultStatus(pgsql_res) );
		DBCSetLastNativeError( PQresultErrorMessage(pgsql_res) );
		PQclear( pgsql_res );
		return;
	}
	
	if( row_count )
		(*row_count) = pgsql_row_count ;
	if( col_count )
		(*col_count) = pgsql_col_count ;
	if( query_field_set )
		(*query_field_set) = pgsql_query_field_set ;
	if( query_result_set )
		(*query_result_set) = pgsql_query_result_set ;
	if( affected_count )
		(*affected_count) = pgsql_affected_count ;
	return;
}

/*
 * Start a transaction by sending "BEGIN" through ExecuteSql.
 * If the last CDBC errno is non-zero afterwards it is replaced with
 * CDBC_ERROR_BEGINTRANSACTION.
 *
 * NOTE(review): nothing visible in ExecuteSql clears the last errno on
 * success, so a stale error from an earlier call would be reported as a
 * failed BEGIN - confirm the caller resets it beforehand.
 */
void BeginTransaction( struct DatabaseConnection *db_conn )
{
	ExecuteSql( db_conn , "BEGIN" , NULL , 0 , NULL , NULL , NULL , NULL , NULL );
	
	if( DBCGetLastErrno() != 0 )
	{
		DBCSetLastErrno( CDBC_ERROR_BEGINTRANSACTION );
	}
}

/*
 * Commit the current transaction by sending "COMMIT" through ExecuteSql.
 * If the last CDBC errno is non-zero afterwards it is replaced with
 * CDBC_ERROR_COMMITTRANSACTION.
 */
void CommitTransaction( struct DatabaseConnection *db_conn )
{
	ExecuteSql( db_conn , "COMMIT" , NULL , 0 , NULL , NULL , NULL , NULL , NULL );
	
	if( DBCGetLastErrno() != 0 )
	{
		DBCSetLastErrno( CDBC_ERROR_COMMITTRANSACTION );
	}
}

/*
 * Abandon the current transaction by sending "ROLLBACK" through ExecuteSql.
 * If the last CDBC errno is non-zero afterwards it is replaced with
 * CDBC_ERROR_ROLLBACKTRANSACTION.
 */
void RollbackTransaction( struct DatabaseConnection *db_conn )
{
	ExecuteSql( db_conn , "ROLLBACK" , NULL , 0 , NULL , NULL , NULL , NULL , NULL );
	
	if( DBCGetLastErrno() != 0 )
	{
		DBCSetLastErrno( CDBC_ERROR_ROLLBACKTRANSACTION );
	}
}

